-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay, Topi] [TF, MXNet] Unravel Index operator #5082
Conversation
…into unravel_index_op
cc: @kevinthesun, @jwfromm, @masahi Please help in reviewing. |
@@ -2509,9 +2515,7 @@ def _parse_param(self, key, value, name, shape): | |||
|
|||
array_ndim = len(np_array.shape) | |||
if array_ndim == 0: | |||
new_array = np.empty([1], dtype=np_array.dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed this because we want to pass the scalar as scalar only and not as a tensor of rank 1.
topi/include/topi/transform.h
Outdated
for (int v = GetConstInt(shape_shape[0]) - 1; v >= 0; --v) { | ||
ret = tvm::if_then_else(i == v, indexmod(indices_divs.back(), shape[v]), ret); | ||
cur_val = indexdiv(indices_divs.back(), shape[v]); | ||
indices_divs.push_back(cur_val); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason that UnravelIndex
from topi/include/topi/detail/ravel_unravel.h
isn't used here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The function in this file returns all the coordinates for a given index. In compute definition we just want a coordinate for the current compute index and not for all of them. I was facing issue while extracting the current coordinate because compute index which is a Var can not be directly used to extract Expr from an array of Exprs. I had to use if_then_else construct for that. Please let me know if I am missing something here and if there is an easier way to achieve this. I could have modified the existing function to meet my purposes for example pass in the coordinate index I want to extract and return just that coordinate. Please let me know if I should implement this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah ok that makes sense. This implementation is good then, no need to change it.
include/tvm/relay/attrs/transform.h
Outdated
@@ -321,6 +321,12 @@ struct ArgWhereAttrs : public tvm::AttrsNode<ArgWhereAttrs> { | |||
} | |||
}; // struct ArgWhereAttrs | |||
|
|||
/*! \brief Attributes used in unravel_index operators */ | |||
struct UnRavelIndexAttrs : public tvm::AttrsNode<UnRavelIndexAttrs> { | |||
TVM_DECLARE_ATTRS(UnRavelIndexAttrs, "relay.attrs.UnRavelIndexAttrs") { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think there's any need to define an attribute type for an operator without attributes. Although argwhere
seems to do the same thing you have, other operators without attributes just don't use one (see nn.batch_flatten
as one example). I'd argue we should try to avoid defining unnecessary attrs to prevent bloat.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok. Thanks. This is good to know. I have removed the attrs for both unravel_index and argwhere.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor comments, once fixed I'll merge this.
Thanks @maheshambule @jwfromm |
* first cut unravel_index * merge fixes * change rates to dilations * unravel_index op relay, topi, mxnet, tf * doc changes * small changes * remove empty unravel and argwhere attrs * remove empty unravel and argwhere attrs
* first cut unravel_index * merge fixes * change rates to dilations * unravel_index op relay, topi, mxnet, tf * doc changes * small changes * remove empty unravel and argwhere attrs * remove empty unravel and argwhere attrs
Adds support for unravel_index op.
NumPy Reference:
https://docs.scipy.org/doc/numpy/reference/generated/numpy.unravel_index.html
Added support for Tensorflow and MXNet frontends.